import numpy as np
import gymnasium as gym
from Environment.environment import Environment
from Environment.Environments.ACDomains.ac_domain import ACDomain, ACObject
from collections.abc import Iterable
import itertools

# Registry of MapDAG variants, keyed by variant name. Each entry is
# [num, relations, maxval]:
#   num       -- number of objects (MapDAG names them "A", "B", "C", ...)
#   relations -- list of [parent_indices, target_index, mapping]; every
#                mapping row is [*parent_values, target_value]
#   maxval    -- number of values per object (values 0..maxval-1 are what
#                compute_all_mapping_tables enumerates); either a list with
#                one entry per object or a single value shared by all objects
map_variants = {
    "map4x2": [
        3,
        [
            [
                [0, 1],
                2,
                [
                    [0, 0, 0],
                    [1, 0, 1],
                    [2, 0, 2],
                    [3, 0, 0],
                    [0, 1, 1],
                    [1, 1, 2],
                    [2, 1, 0],
                    [3, 1, 1],
                ],
            ],
        ],
        [4, 2, 3],
    ],
}


def create_map_relation(names, modval, mapping):
    """Return a callable that applies a lookup-table relation to a dict of objects.

    Args:
        names: pair ``(parent_names, target_name)``. ``parent_names`` is a
            sequence of keys into the objects dict; ``target_name`` is the key
            of the object that gets written.
        modval: modulus applied to the looked-up value before it is stored.
        mapping: iterable of rows ``[*parent_values, target_value]``. If a
            parent combination appears more than once, the last row wins.

    Returns:
        ``map_relation(objects)``. It reads ``.attribute`` from each parent
        object in order, looks that tuple up in the table, and stores the
        result modulo ``modval`` on the target object's ``.attribute``. A
        parent combination missing from ``mapping`` raises ``KeyError``.
    """
    table = {}
    for row in mapping:
        table[tuple(row[:-1])] = row[-1]

    def map_relation(objects):
        key = tuple(objects[parent].attribute for parent in names[0])
        new_value = table[key] % modval
        objects[names[1]].attribute = new_value

    return map_relation

def compute_all_mapping_tables(env):
    """Enumerate every lookup table over the non-outcome objects of ``env``.

    The parent rows are all combinations of values ``0..maxvals[i]-1`` for
    every object except the last. Each returned table is those rows with one
    extra column holding an outcome value in ``0..maxvals[-1]-1``. Every
    assignment of outcome values to rows is produced, so the list has
    ``maxvals[-1] ** len(rows)`` entries, which grows very fast.

    For two parent objects the rows are in lexicographic order, with the first
    parent varying slowest. With more parents the order follows
    ``np.meshgrid``'s default ``'xy'`` indexing followed by a transpose.

    Args:
        env: object exposing ``maxvals`` (an indexable sequence, one entry per
            object) and ``num`` (number of objects; assumed to equal
            ``len(env.maxvals)``).

    Returns:
        list of integer arrays, each with ``env.num`` columns.
    """
    parent_ranges = [np.arange(limit) for limit in env.maxvals[:-1]]
    grid = np.array(np.meshgrid(*parent_ranges))
    parent_rows = grid.T.reshape(-1, env.num - 1)

    outcome_values = np.arange(env.maxvals[-1])
    tables = []
    for outcome in itertools.product(outcome_values, repeat=len(parent_rows)):
        outcome_column = np.expand_dims(outcome, -1)
        tables.append(np.concatenate([parent_rows, outcome_column], axis=-1))
    return tables

class MapDAG(ACDomain):
    """Domain whose objects are linked by lookup-table relations.

    Objects are named "A", "B", "C", ... in index order. Each relation of the
    selected variant sets its target object's value from the values of its
    parent objects via a lookup table (see ``create_map_relation``). The last
    object is recorded as ``outcome_variable``.
    """

    def __init__(self, frameskip = 1, variant="", fixed_limits=False, cf_states=False, mapping=None):
        """Build the objects and relations for ``map_variants[variant]``.

        Args:
            frameskip: forwarded to ``ACDomain.__init__``.
            variant: key into the module-level ``map_variants`` table (also
                forwarded to ``ACDomain.__init__``).
            fixed_limits: forwarded to ``ACDomain.__init__``.
            cf_states: forwarded to ``ACDomain.__init__``.
            mapping: optional lookup table (rows ``[*parent_values, target_value]``)
                that replaces the table of the first relation.

        Raises:
            KeyError: if ``variant`` is not a key of ``map_variants`` (this
                includes the default ``""``).
        """
        num, relations, maxval = map_variants[variant]
        # Copy before any edit. ``map_variants`` is shared module state, so
        # assigning into its relation lists in place would leak this
        # instance's ``mapping`` into every MapDAG built later for the same
        # variant.
        relations = [list(rel) for rel in relations]
        if mapping is not None: # replace the first relation's mapping table
            relations[0][-1] = mapping

        self.all_names = [self.convert_idx_to_name(i) for i in range(num)]
        # maxval is either per-object (indexed like all_names) or a single value shared by every object
        self.maxval_dict = {self.all_names[i]: maxval[i] for i in range(len(maxval))} if isinstance(maxval, Iterable) else {n: maxval for n in self.all_names}
        self.objects = {n: ACObject(n, self.maxval_dict[n]) for n in self.all_names} # dict of name to value
        self.maxvals = maxval
        self.relations = relations
        self.num = num
        self.binary_relations = [self.create_relation(*rel) for rel in relations] # must get set prior to calling super (), the order follows the order of operations
        # name of the object each relation writes, parallel to binary_relations
        self.relation_outcome = [self.all_names[rel[1]] for rel in relations]
        # one entry per object except the last; its meaning is defined by ACDomain -- TODO confirm
        self.passive_mask = np.zeros(len(self.all_names)-1)
        self.outcome_variable = self.all_names[-1]
        super().__init__(frameskip, variant, fixed_limits, cf_states=cf_states)

    def convert_idx_to_name(self, idx):
        """Return the single-letter object name for ``idx`` (0 -> 'A', 1 -> 'B', ...)."""
        return chr(ord('@')+ idx + 1)

    def create_relation(self, parents, target, mapping):
        """Build the relation callable for one ``[parents, target, mapping]`` entry.

        Args:
            parents: indices of the parent objects, in the column order of ``mapping``.
            target: index of the object the relation writes.
            mapping: rows ``[*parent_values, target_value]``.

        Returns:
            Callable taking the objects dict. The stored value is reduced
            modulo the target object's maxval.
        """
        names = ([self.all_names[p] for p in parents], self.all_names[target])
        relation = create_map_relation(names, self.maxval_dict[self.all_names[target]], mapping)
        return relation
